import torch
from torch import nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Outer-product attention between two projected feature vectors.

    Both inputs go through their own linear layer (``FC1`` for ``po_z``,
    ``FC2`` for ``pr_z``). The outer product of the two projections gives a
    ``hidden_dim x hidden_dim`` score matrix per sample, which is
    softmax-normalised along its last axis and then used to take a weighted
    sum over the projected ``pr_z`` features.
    """

    def __init__(self, hidden_dim):
        """Create the two ``hidden_dim -> hidden_dim`` projection layers."""
        super().__init__()
        self.FC1 = nn.Linear(hidden_dim, hidden_dim)
        self.FC2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, po_z, pr_z):
        """Apply the attention to a batch of vector pairs.

        Args:
            po_z: Tensor of shape (batch_size, hidden_dim).
            pr_z: Tensor of shape (batch_size, hidden_dim).

        Returns:
            A tuple ``(out, A)``:
                out: Tensor of shape (batch_size, hidden_dim) where
                    ``out[b, i] = sum_j A[b, i, j] * FC2(pr_z)[b, j]``.
                A: Tensor of shape (batch_size, hidden_dim, hidden_dim);
                    every slice ``A[b, i, :]`` sums to 1.
        """
        po_proj = self.FC1(po_z)
        pr_proj = self.FC2(pr_z)

        # Column view of the po projection and row view of the pr projection,
        # so that their batched product is a per-sample outer product.
        po_col = po_proj.unsqueeze(2)  # (batch_size, hidden_dim, 1)
        pr_row = pr_proj.unsqueeze(1)  # (batch_size, 1, hidden_dim)
        scores = torch.bmm(po_col, pr_row)  # (batch_size, hidden_dim, hidden_dim)

        # Normalise each row of the score matrix into a distribution.
        A = scores.softmax(dim=-1)

        # Weighted sum of the projected pr features for every row of A.
        pr_col = pr_proj.unsqueeze(2)  # (batch_size, hidden_dim, 1)
        weighted = torch.bmm(A, pr_col)  # (batch_size, hidden_dim, 1)
        out = weighted.squeeze(2)

        return out, A